#!/usr/bin/env python

# This tool performs various parsing operations on the skip *-extc.h files,
# producing lists of runtime entrypoints.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import itertools
import logging
import os
import re
import subprocess
import sys
import tempfile

try:
    from cStringIO import StringIO
except ImportError:
    from io import StringIO

try:
    # Python 3.3+: shlex.quote is the supported shell-quoting helper.
    from shlex import quote as shell_quote
except ImportError:
    # Python 2: shlex.quote does not exist yet.
    from pipes import quote as shell_quote

try:
    import pipes
except ImportError:
    # The pipes module was removed in Python 3.13.
    pipes = None


# Import the project's ``common`` module relatively when this file is loaded
# as part of a package, and as a plain top-level module otherwise.
if __package__:
    from . import common
else:
    import common

# Module-level logger, named after this file's basename.
logger = logging.getLogger(os.path.basename(__file__))


def recursiveListDir(path):
    """Lazily yield the normalized path of every file found under *path*.

    Directories themselves are not yielded, only the files they contain.
    Each yielded path is the walked directory joined with the file name, so
    it starts with *path* (after normalization).
    """
    for parent, _subdirs, names in os.walk(path):
        joined = (os.path.join(parent, name) for name in names)
        for fullPath in map(os.path.normpath, joined):
            yield fullPath


def _stripComments(text):
    """Return *text* with C/C++ ``//`` and ``/* ... */`` comments removed."""
    # One pass, so whichever comment opener comes first wins: a '//' inside a
    # block comment (or a '/*' inside a line comment) is not misinterpreted.
    # The flags MUST be passed by keyword: the fourth positional argument of
    # re.sub() is `count`, not `flags`.
    return re.sub(r'//[^\n]*|/\*.*?\*/', '', text, flags=re.DOTALL)


def _parseEntrypoints(fname, data):
    """Extract the extern SKIP_/SKIPC_ function declarations from a header.

    Looks for declarations shaped like
    ``extern <ret type> SKIPC?_functionName(<params>);`` in *data* (the text
    of the header *fname*) after stripping comments.

    Returns a list of ``(fname, return type, function name, parameters)``
    tuples, in the order the declarations appear.

    Raises ValueError if a candidate declaration cannot be split into those
    parts (for example an extern SKIP_ variable rather than a function).
    """
    data = _stripComments(data)
    entrypoints = []
    for decl in re.findall(
            r"\bextern\s+[^\";]*?SKIPC?_.*?;", data, re.DOTALL):
        # Collapse newlines, tabs and runs of blanks so the declaration is a
        # single line with single-space separators.
        line = re.sub(r'\s+', ' ', decl)
        m = re.match(r'^extern (.*)(SKIPC?_\w*)(\(.*);$', line)
        if not m:
            raise ValueError(
                '%s: cannot parse extern declaration: %r' % (fname, line))
        entrypoints.append((
            fname,
            m.group(1).strip(),
            m.group(2).strip(),
            m.group(3).strip(),
        ))
    return entrypoints


def _writeIfChanged(path, content):
    """Write *content* to *path* unless the file already holds exactly it.

    Returns True if the file was (re)written, False if it was left alone.
    """
    if os.path.exists(path):
        with open(path, 'r') as f:
            if f.read() == content:
                return False
    with open(path, 'w') as f:
        f.write(content)
    return True


def main():
    """Command-line entry point.

    Parses the arguments, scrapes every ``*-extc.h`` header below
    ``<runtime-path>/skip`` (plus ``skip/plugin-extc.h``) for extern
    SKIP_/SKIPC_ declarations and renders the requested output, either to
    stdout or to ``--output`` (which is only rewritten when its content
    would change).

    Returns 0 on success; exits with status 1 on an unknown command.
    """
    parser = argparse.ArgumentParser(
        parents=[common.commonArguments(needsBackend=False)])
    parser.add_argument('command',
                        help='''\
File to produce: preamble, dynamic-list, extc-refs.

preamble - Produces a preamble file (llvm assembly) which defines the runtime
           functions for use by the compiler's output.

dynamic-list - Produces a linker versions script referencing all the public
               extern C symbols in the runtime.

extc-refs - Produces a C++ file with an array listing of all the public extern C
            symbols in the runtime to force the linker to pull in the runtime.
''')
    parser.add_argument(
        '-o', '--output',
        help=('File to write.  The file will only be updated if it is '
              'changed.'))
    parser.add_argument('--compiler', default='clang')
    parser.add_argument('--ndebug', action='store_true')
    parser.add_argument(
        '--m32', action='store_true',
        help='Generate 32 bit version.')
    parser.add_argument('--runtime-path')
    parser.add_argument('--install_dir', help=argparse.SUPPRESS)
    args = common.parse_args(parser)

    if not args.runtime_path:
        args.runtime_path = os.path.join(common.source_dir,
                                         'runtime/native/include')

    # recursiveListDir() already yields full paths, so they must not be
    # joined onto the skip directory a second time.  Absolute paths keep the
    # generated '#include "..."' lines valid for a relative --runtime-path.
    skipDir = os.path.abspath(os.path.join(args.runtime_path, 'skip'))
    files = {
        f for f in recursiveListDir(skipDir)
        if f.endswith('-extc.h')
        # '.#name' files are skipped; test the file name, not the full path.
        and not os.path.basename(f).startswith('.#')
    }
    # plugin-extc.h is always parsed, whether or not the walk found it.
    files.add(os.path.join(skipDir, 'plugin-extc.h'))

    entrypoints = []
    for fname in sorted(files):
        with open(fname, 'r') as f:
            entrypoints.extend(_parseEntrypoints(fname, f.read()))

    outfile = StringIO()
    # NOTE: the help text also advertises 'dynamic-list', but no generator
    # for it exists here, so it is reported as an unknown command.
    if args.command == 'extc-refs':
        dumpExtcRefs(args, outfile, entrypoints)
    elif args.command == 'preamble':
        dumpPreamble(outfile, args, entrypoints)
    else:
        print('Unknown command', args.command, file=sys.stderr)
        sys.exit(1)

    out = outfile.getvalue()
    if not args.output:
        print(out)
        return 0

    # Compare against exactly what would be written, so an up-to-date file
    # is left untouched.
    _writeIfChanged(args.output, out + '\n')
    return 0


def computeEntrypointByFile(entrypoints):
    """Group entrypoint function names by the file that declares them.

    Each item of *entrypoints* is indexed positionally: element 0 is the
    file name and element 2 the function name.

    Returns a dict mapping each file name to the set of function names
    seen for it.
    """
    grouped = {}
    for record in entrypoints:
        header = record[0]
        symbol = record[2]
        grouped.setdefault(header, set()).add(symbol)
    return grouped


def dumpExtcRefs(args, outfile, entrypoints):
    """Write a C++ source that references every scraped entrypoint.

    The output consists of a banner recording the generating command line,
    one ``#include`` per header (relative to ``args.runtime_path``) and a
    ``_skip_references`` array holding each function cast to ``void*``,
    grouped by header and sorted by name.

    args        -- parsed arguments; only ``runtime_path`` is read here.
    outfile     -- writable text stream receiving the output.
    entrypoints -- (fname, return type, function name, parameters) tuples.
    """
    print('// This file is generated with the command:', file=outfile)
    # shell_quote is shlex.quote (pipes.quote on Python 2); the pipes module
    # itself no longer exists on Python 3.13+.
    print('//', ' '.join(map(shell_quote, sys.argv)), file=outfile)
    print(file=outfile)

    entrypointsByFile = computeEntrypointByFile(entrypoints)
    headers = sorted(entrypointsByFile)

    for header in headers:
        print('#include "%s"' % (os.path.relpath(header, args.runtime_path),),
              file=outfile)

    print(file=outfile)

    print("void* _skip_references[] = {", file=outfile)

    for header in headers:
        print(file=outfile)
        print("  /* %s */" % (os.path.basename(header),), file=outfile)
        for entry in sorted(entrypointsByFile[header]):
            print("  (void*)%s," % (entry,), file=outfile)

    print("};", file=outfile)


# generate gen_preamble.XXXXXX.cpp that includes -extc.h headers, plus an
# array containing references to the SKIPC?_xxx() runtime functions.
# Then compile it and return the .ll produced by clang.
def genDeclarations(args, entrypoints):
    """Compile a throwaway C++ file and return the LLVM assembly it yields.

    A temporary ``gen_preamble.*.cpp`` is written that includes every header
    named in *entrypoints* and defines a ``HACK_REFERENCES`` array taking the
    address of each entrypoint, so that the compiler emits a declaration for
    all of them.  The file is compiled with ``args.compiler`` using
    ``-S -emit-llvm -o -`` and the captured output is returned as text.

    args        -- parsed arguments; reads ``compiler``, ``runtime_path``,
                   ``m32``, ``ndebug`` and ``keep_temp`` (the latter is not
                   defined in this file; presumably one of the common
                   arguments -- confirm against ``common``).
    entrypoints -- (fname, return type, function name, parameters) tuples.

    Raises subprocess.CalledProcessError if the compiler exits non-zero.
    """
    with tempfile.NamedTemporaryFile(prefix='gen_preamble.',
                                     suffix='.cpp',
                                     mode='w',
                                     delete=not args.keep_temp) as tmpSource:
        if args.keep_temp:
            logger.debug('(KEEPING TEMP FILE: %s)', tmpSource.name)

        for header in sorted({x[0] for x in entrypoints}):
            print('#include "%s"' % (header,), file=tmpSource)

        print('void* HACK_REFERENCES[] = {', file=tmpSource)
        for entrypoint in entrypoints:
            print('  (void*)&%s,' % (entrypoint[2],), file=tmpSource)
        print('};', file=tmpSource)

        # The compiler reads the file by name, so the buffered text has to
        # reach the disk before it is launched.
        tmpSource.flush()

        command = [
            args.compiler,
            '-I', args.runtime_path,
            '-DGEN_PREAMBLE=1',
            '-D_FORTIFY_SOURCE=0',  # prevent memset/memcpy_chk
            '-S',
            '-c',
            '-emit-llvm',
            '-o', '-',
            '-x', 'c++', '-std=c++14',
            '-fPIC',
            '-fdeclspec',
        ]
        if args.m32:
            command.append('-m32')
        command.append('-DNDEBUG' if args.ndebug else '-DDEBUG')
        command.append(tmpSource.name)

        # Must run while the temporary file still exists.  Decode as UTF-8
        # (a superset of ASCII) so non-ASCII bytes echoed by the compiler,
        # e.g. a temp path in the ModuleID comment, do not raise.
        result = subprocess.check_output(tuple(command)).decode('utf-8')

    return result


def dumpHeader(outfile):
    """Write the fixed block of LLVM/C++ runtime declarations to *outfile*.

    The text is constant: LLVM intrinsics, the C++ exception-handling
    support functions, the SkipException typeinfo symbol and ``abort``.
    A trailing newline is appended after the block.
    """
    header = '''
;
; LLVM builtins
;
declare void @llvm.lifetime.start(i64, i8* nocapture) argmemonly nounwind
declare void @llvm.lifetime.end(i64, i8* nocapture) argmemonly nounwind

declare i32 @llvm.eh.typeid.for(i8*) nounwind readnone

declare i64 @llvm.ctlz.i64(i64, i1)
declare i64 @llvm.cttz.i64(i64, i1)
declare i64 @llvm.ctpop.i64(i64)

declare double @llvm.sin.f64(double %Val)
declare double @llvm.cos.f64(double %Val)
declare double @llvm.floor.f64(double %Val)
declare double @llvm.ceil.f64(double %Val)
declare double @llvm.round.f64(double %Val)
declare double @llvm.sqrt.f64(double %Val)
declare double @llvm.pow.f64(double %Val, double %Power)

declare i32 @__gxx_personality_v0(...)
declare i8* @__cxa_begin_catch(i8*)
declare void @__cxa_end_catch()

@_ZTIN4skip13SkipExceptionE = external constant { i8*, i8*, i8* }, align 8

; Delete after update_lkg
declare void @abort() noreturn

'''
    outfile.write(header)
    outfile.write('\n')


def dumpPreamble(outfile, args, entrypoints):
    """Write the LLVM assembly preamble declaring the runtime entrypoints.

    The compiler output from genDeclarations() is filtered line by line:

    * blank lines, the HACK_REFERENCES global, the module header lines and
      all ``!N`` metadata definitions are dropped;
    * ``declare`` lines for known SKIP_/SKIPC_ entrypoints are collected and
      emitted at the end, grouped by declaring header and sorted by name;
    * other ``declare`` lines are passed through if they name an llvm
      intrinsic or whitelisted function (with a warning on stderr when not
      whitelisted) and dropped with a warning otherwise;
    * everything else is passed through unchanged.

    outfile     -- writable text stream receiving the preamble.
    args        -- parsed arguments, forwarded to genDeclarations().
    entrypoints -- (fname, return type, function name, parameters) tuples.

    Raises RuntimeError if an entrypoint was compiled as taking ``(...)``
    (i.e. declared without ``(void)``), or if a line that is kept still
    refers to unnamed metadata.
    """
    print('; This file is generated with the command:', file=outfile)
    # shell_quote is shlex.quote (pipes.quote on Python 2); the pipes module
    # itself no longer exists on Python 3.13+.
    print(';', ' '.join(map(shell_quote, sys.argv)), file=outfile)
    print(file=outfile)

    result = genDeclarations(args, entrypoints)

    # function name -> header file that declares it
    entrypointsByName = {e[2]: e[0] for e in entrypoints}
    # header file -> 'declare' lines scraped from the compiler output
    declsByFile = {e[0]: [] for e in entrypoints}

    # don't warn about these decls if we see them
    declWhitelist = [
        'llvm.memcpy.p0i8.p0i8.i64',
        'llvm.memset.p0i8.i64',
        'memmem'
    ]
    declRE = '|'.join(r'\b%s\b' % re.escape(s) for s in declWhitelist)

    # Lines that only make sense inside the temporary module: the
    # HACK_REFERENCES global and module-level bookkeeping.
    droppedPrefixes = (
        '@HACK_REFERENCES = ',
        '!llvm.ident = ',
        '!llvm.module.flags = ',
        '; ModuleID =',
        'source_filename =',
    )

    wroteLine = False
    for orig_line in result.splitlines():
        line = orig_line.strip()

        if not line:
            continue

        # Also drops the unnamed metadata definitions ('!0 = ...').
        if line.startswith(droppedPrefixes) or re.match(r'!\d+ = ', line):
            continue

        if line.startswith('declare ') and '@' in line:
            m = re.search(r'@(SKIPC?_\w+)', line)
            if m:
                entrypoint = m.group(1)
                if entrypoint not in entrypointsByName:
                    continue
                if '(...)' in line:
                    print("ERROR: Make sure that functions which take no " +
                          "parameters are declared as '%s(void)'" % entrypoint,
                          file=sys.stderr)
                    raise RuntimeError("incorrect void function")
                # save the SKIP decls to emit below
                declsByFile[entrypointsByName[entrypoint]].append(line)
                continue

            # whitelist certain functions.
            m = re.search(r'@(llvm[\w.]+|' + declRE + ')', line)
            if not m:
                # Diagnostics go to stderr so they never end up inside the
                # generated file when it is written to stdout.
                print("Warning: unexpected declare: ", line, file=sys.stderr)
                continue
            if not re.match(declRE, m.group(1)):
                # if you get here, update declWhitelist.
                print("Warning: unexpected llvm declare: ", line,
                      file=sys.stderr)
            print(orig_line, file=outfile)
            wroteLine = True
            continue

        # Ensure that no llvm unnamed metadata references snuck in (since we
        # strip the unnamed metadata definitions above).  Raised explicitly
        # rather than asserted so the check survives `python -O`.
        if re.search(r'!\d+', line):
            raise RuntimeError(
                'unexpected unnamed metadata reference: %s' % (line,))
        print(orig_line, file=outfile)
        wroteLine = True

    # Separate the passed-through lines from the header with a blank line.
    if wroteLine:
        print(file=outfile)

    dumpHeader(outfile)

    def declKey(decl):
        # Sort by the declared symbol name; the full line breaks ties.  The
        # '@name' sits after 'declare <type>', hence search, not match.
        m = re.search(r'@(\w+)', decl)
        return m.group(1) + decl if m else decl

    # now emit the SKIP decls we scraped earlier
    for fname, decls in sorted(declsByFile.items()):
        print(';', file=outfile)
        print('; %s' % (os.path.relpath(fname, common.source_dir),),
              file=outfile)
        print(';', file=outfile)
        print("\n".join(sorted(decls, key=declKey)), file=outfile)
        print(file=outfile)

    print("%struct._FunctionSignature = type { i8*, void (...)*, i8, i8, i8* }",
          file=outfile)

    print('; -------- END OF PREAMBLE --------', file=outfile)


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    sys.exit(main())
